{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualize"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# check whether run in Colab\n",
    "if 'google.colab' in sys.modules:\n",
    "    print('Running in Colab.')\n",
    "    !git clone https://github.com/PriceWang/MAECG.git\n",
    "    !pip install timm==0.4.12 wfdb==4.1.0 neurokit2==0.2.3\n",
    "    sys.path.append('./MAECG')\n",
    "\n",
    "import torch\n",
    "from vit_mae import mae_vit_base_patch32\n",
    "import random\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "    \n",
    "!wget -P ./datasets https://huggingface.co/datasets/PriceWang/dataset/resolve/main/maecg/CINC2021_ul.pth\n",
    "!wget -P ./datasets https://huggingface.co/datasets/PriceWang/dataset/resolve/main/maecg/MITDB_af_test.pth\n",
    "!wget -P ./datasets https://huggingface.co/datasets/PriceWang/dataset/resolve/main/maecg/ECGIDDB_dn_test.pth\n",
    "!wget -P ./ckpts https://huggingface.co/PriceWang/model/resolve/main/maecg/ViT-Base-DN.pth\n",
    "!wget -P ./ckpts https://huggingface.co/PriceWang/model/resolve/main/maecg/ViT-Base.pth"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reconstruction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def recon(model, data, mask_ratio, patch_size):\n",
    "    latent, mask, ids_restore = model.forward_encoder(\n",
    "        data.unsqueeze(0).unsqueeze(0), mask_ratio\n",
    "    )\n",
    "    pred = model.forward_decoder(latent, ids_restore)\n",
    "    masked_input = data.squeeze() * (\n",
    "        1 - mask.squeeze().repeat_interleave(patch_size)\n",
    "    )\n",
    "    masked_input = masked_input.numpy()\n",
    "    temp_mask = masked_input == 0\n",
    "    plot_recon = np.copy(masked_input)\n",
    "    plot_masked = np.ma.masked_where(masked_input == 0, masked_input)\n",
    "    pred_unpatchified = model.unpatchify(pred).squeeze()\n",
    "    plot_recon[temp_mask] = pred_unpatchified.detach().numpy()[temp_mask]\n",
    "    lines = [\n",
    "        i * patch_size for i, x in enumerate(1 - mask.squeeze()) if x == 0\n",
    "    ] + [(i + 1) * patch_size for i, x in enumerate(1 - mask.squeeze()) if x == 0]\n",
    "    return masked_input, data, plot_masked, plot_recon, lines"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reconstruction Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_one_recon(axes, col, masked_input, plot_ori, plot_masked, plot_recon, lines):\n",
    "    axes[0][col].plot(plot_ori, \"darkblue\", label=\"Original\")\n",
    "    axes[0][col].fill_between(\n",
    "        x=range(len(masked_input)),\n",
    "        y1=-1.1,\n",
    "        y2=1.1,\n",
    "        color=\"none\",\n",
    "        hatch=\"//\",\n",
    "        edgecolor=\"black\",\n",
    "        linewidth=0.0,\n",
    "        where=masked_input == 0,\n",
    "        alpha=0.3,\n",
    "    )\n",
    "    axes[0][col].vlines(\n",
    "        x=list(set(lines)), ymin=-1.1, ymax=1.1, color=\"black\", alpha=0.1\n",
    "    )\n",
    "    axes[0][col].set_xlim([-30, 510])\n",
    "    axes[0][col].set_ylim([-1.1, 1.1])\n",
    "    axes[0][col].set_xticks([])\n",
    "    axes[0][col].set_yticks([])\n",
    "\n",
    "    axes[1][col].plot(plot_masked, \"k\", label=\"Masked\")\n",
    "    axes[1][col].fill_between(\n",
    "        x=range(len(masked_input)),\n",
    "        y1=-1.1,\n",
    "        y2=1.1,\n",
    "        color=\"none\",\n",
    "        hatch=\"//\",\n",
    "        edgecolor=\"black\",\n",
    "        linewidth=0.0,\n",
    "        where=masked_input == 0,\n",
    "        alpha=0.3,\n",
    "    )\n",
    "    axes[1][col].vlines(\n",
    "        x=list(set(lines)), ymin=-1.1, ymax=1.1, color=\"black\", alpha=0.1\n",
    "    )\n",
    "    axes[1][col].set_xlim([-30, 510])\n",
    "    axes[1][col].set_ylim([-1.1, 1.1])\n",
    "    axes[1][col].set_xticks([])\n",
    "    axes[1][col].set_yticks([])\n",
    "\n",
    "    axes[2][col].plot(plot_recon, \"darkred\", label=\"Reconstructed\")\n",
    "    axes[2][col].fill_between(\n",
    "        x=range(len(masked_input)),\n",
    "        y1=-1.1,\n",
    "        y2=1.1,\n",
    "        color=\"none\",\n",
    "        hatch=\"//\",\n",
    "        edgecolor=\"black\",\n",
    "        linewidth=0.0,\n",
    "        where=masked_input == 0,\n",
    "        alpha=0.3,\n",
    "    )\n",
    "    axes[2][col].vlines(\n",
    "        x=list(set(lines)), ymin=-1.1, ymax=1.1, color=\"black\", alpha=0.1\n",
    "    )\n",
    "    axes[2][col].set_xlim([-30, 510])\n",
    "    axes[2][col].set_ylim([-1.1, 1.1])\n",
    "    axes[2][col].set_xticks([])\n",
    "    axes[2][col].set_yticks([])\n",
    "\n",
    "    axes[3][col].plot(plot_ori, \"darkblue\", label=\"Original\")\n",
    "    axes[3][col].plot(plot_recon, \"darkred\", label=\"Reconstructed\")\n",
    "    axes[3][col].fill_between(\n",
    "        x=range(len(masked_input)),\n",
    "        y1=-1.1,\n",
    "        y2=1.1,\n",
    "        color=\"none\",\n",
    "        hatch=\"//\",\n",
    "        edgecolor=\"black\",\n",
    "        linewidth=0.0,\n",
    "        where=masked_input == 0,\n",
    "        alpha=0.3,\n",
    "    )\n",
    "    axes[3][col].vlines(\n",
    "        x=list(set(lines)), ymin=-1.1, ymax=1.1, color=\"black\", alpha=0.1\n",
    "    )\n",
    "    axes[3][col].set_xlim([-30, 510])\n",
    "    axes[3][col].set_ylim([-1.1, 1.1])\n",
    "    axes[3][col].set_xticks([])\n",
    "    axes[3][col].set_yticks([])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = mae_vit_base_patch32(signal_length=480)\n",
    "checkpoint = torch.load(\n",
    "    \"./ckpts/ViT-Base.pth\", map_location=\"cpu\"\n",
    ")\n",
    "model.load_state_dict(checkpoint[\"model\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Same dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = torch.load(\"./datasets/CINC2021_ul.pth\")\n",
    "index = [random.randint(0, len(dataset) - 1) for _ in range(5)]\n",
    "\n",
    "fig = plt.figure(figsize=(25, 5))\n",
    "axes = fig.subplots(nrows=4, ncols=5)\n",
    "for i in range(5):\n",
    "    masked_input, plot_ori, plot_masked, plot_recon, lines = recon(\n",
    "        model, dataset[index[i]], 0.5, 32\n",
    "    )\n",
    "    plot_one_recon(axes, i, masked_input, plot_ori, plot_masked, plot_recon, lines)\n",
    "lines = []\n",
    "labels = []\n",
    "axLine, axLabel = axes[0][0].get_legend_handles_labels()\n",
    "lines.extend(axLine)\n",
    "labels.extend(axLabel)\n",
    "axLine, axLabel = axes[1][0].get_legend_handles_labels()\n",
    "lines.extend(axLine)\n",
    "labels.extend(axLabel)\n",
    "axLine, axLabel = axes[2][0].get_legend_handles_labels()\n",
    "lines.extend(axLine)\n",
    "labels.extend(axLabel)\n",
    "\n",
    "lgd = fig.legend(\n",
    "    lines,\n",
    "    labels,\n",
    "    loc=\"lower center\",\n",
    "    bbox_to_anchor=(0.5, -0.03),\n",
    "    ncol=3,\n",
    "    prop={\"size\": 18},\n",
    ")\n",
    "fig.subplots_adjust(hspace=0, wspace=0.05)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Different Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "dataset = torch.load(\"./datasets/MITDB_af_test.pth\")\n",
    "index = [random.randint(0, len(dataset) - 1) for _ in range(5)]\n",
    "\n",
    "fig = plt.figure(figsize=(25, 5))\n",
    "axes = fig.subplots(nrows=4, ncols=5)\n",
    "for i in range(5):\n",
    "    masked_input, plot_ori, plot_masked, plot_recon, lines = recon(\n",
    "        model, dataset[index[i]][0], 0.5, 32\n",
    "    )\n",
    "    plot_one_recon(axes, i, masked_input, plot_ori, plot_masked, plot_recon, lines)\n",
    "lines = []\n",
    "labels = []\n",
    "axLine, axLabel = axes[0][0].get_legend_handles_labels()\n",
    "lines.extend(axLine)\n",
    "labels.extend(axLabel)\n",
    "axLine, axLabel = axes[1][0].get_legend_handles_labels()\n",
    "lines.extend(axLine)\n",
    "labels.extend(axLabel)\n",
    "axLine, axLabel = axes[2][0].get_legend_handles_labels()\n",
    "lines.extend(axLine)\n",
    "labels.extend(axLabel)\n",
    "\n",
    "lgd = fig.legend(\n",
    "    lines,\n",
    "    labels,\n",
    "    loc=\"lower center\",\n",
    "    bbox_to_anchor=(0.5, -0.03),\n",
    "    ncol=3,\n",
    "    prop={\"size\": 18},\n",
    ")\n",
    "fig.subplots_adjust(hspace=0, wspace=0.05)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Denoise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def denoise(model, dataset, index):\n",
    "    wn = dataset.signals_wn[index]\n",
    "    won = dataset.signals_won[index]\n",
    "    latent, _, ids_restore = model.forward_encoder(wn.unsqueeze(0).unsqueeze(0), 0)\n",
    "    pred = model.forward_decoder(latent, ids_restore)\n",
    "    pred_unpatchified = model.unpatchify(pred).squeeze()\n",
    "    plot_denoise = pred_unpatchified.detach().numpy()\n",
    "    return won, wn, plot_denoise\n",
    "\n",
    "\n",
    "def plot_one_denoise(axes, col, plot_target, plot_input, plot_denoise):\n",
    "    axes[0][col].plot(\n",
    "        plot_target, \"darkblue\", label=\"Target\".center(20) + \"\\n\" \"(Without Noise)\"\n",
    "    )\n",
    "    axes[0][col].set_xlim([-30, 510])\n",
    "    axes[0][col].set_ylim([-1.1, 1.1])\n",
    "    axes[0][col].set_xticks([])\n",
    "    axes[0][col].set_yticks([])\n",
    "    axes[1][col].plot(plot_input, \"k\", label=\"Input\".center(16) + \"\\n\" + \"(With Noise)\")\n",
    "    axes[1][col].set_xlim([-30, 510])\n",
    "    axes[1][col].set_ylim([-1.1, 1.1])\n",
    "    axes[1][col].set_xticks([])\n",
    "    axes[1][col].set_yticks([])\n",
    "    axes[2][col].plot(\n",
    "        plot_denoise, \"darkred\", label=\"Output\".center(12) + \"\\n\" + \"(Denoised)\"\n",
    "    )\n",
    "    axes[2][col].set_xlim([-30, 510])\n",
    "    axes[2][col].set_ylim([-1.1, 1.1])\n",
    "    axes[2][col].set_xticks([])\n",
    "    axes[2][col].set_yticks([])\n",
    "    axes[3][col].plot(\n",
    "        plot_target, \"darkblue\", label=\"Target\".center(20) + \"\\n\" \"(Without Noise)\"\n",
    "    )\n",
    "    axes[3][col].plot(\n",
    "        plot_denoise, \"darkred\", label=\"Output\".center(12) + \"\\n\" + \"(Denoised)\"\n",
    "    )\n",
    "    axes[3][col].set_xlim([-30, 510])\n",
    "    axes[3][col].set_ylim([-1.1, 1.1])\n",
    "    axes[3][col].set_xticks([])\n",
    "    axes[3][col].set_yticks([])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = mae_vit_base_patch32(signal_length=480)\n",
    "ckpt = torch.load(\n",
    "    \"./ckpts/ViT-Base-DN.pth\", map_location=\"cpu\"\n",
    ")\n",
    "model.load_state_dict(ckpt[\"model\"])\n",
    "dataset = torch.load(\"./datasets/ECGIDDB_dn_test.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "index = [random.randint(0, len(dataset) - 1) for _ in range(4)]\n",
    "fig = plt.figure(figsize=(20, 5))\n",
    "axes = fig.subplots(nrows=4, ncols=4)\n",
    "for i in range(4):\n",
    "    plot_target, plot_input, plot_denoise = denoise(model, dataset, index[i])\n",
    "    plot_one_denoise(axes, i, plot_target, plot_input, plot_denoise)\n",
    "lines = []\n",
    "labels = []\n",
    "axLine, axLabel = axes[0][0].get_legend_handles_labels()\n",
    "lines.extend(axLine)\n",
    "labels.extend(axLabel)\n",
    "axLine, axLabel = axes[1][0].get_legend_handles_labels()\n",
    "lines.extend(axLine)\n",
    "labels.extend(axLabel)\n",
    "axLine, axLabel = axes[2][0].get_legend_handles_labels()\n",
    "lines.extend(axLine)\n",
    "labels.extend(axLabel)\n",
    "lgd = fig.legend(\n",
    "    lines,\n",
    "    labels,\n",
    "    loc=\"lower center\",\n",
    "    bbox_to_anchor=(0.5, -0.075),\n",
    "    ncol=3,\n",
    "    prop={\"size\": 18},\n",
    ")\n",
    "fig.subplots_adjust(hspace=0, wspace=0.05)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
